#!/usr/bin/env python
# Copyright 2020 The PyTorch Lightning team and Microsoft Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Modified script from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/zero_to_fp32.py.

This script extracts fp32 consolidated weights from ZeRO stage 2 and 3 DeepSpeed checkpoints. It gets
copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
the future. Once extracted, the weights don't require DeepSpeed and can be used in any
application. Additionally, the script has been modified to keep the Lightning state inside the state dict
so that Model.load_from_checkpoint('...') can be run on the converted file.

Example usage within the Lightning checkpoint directory where 'latest' is found:

>>> from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict # doctest: +SKIP

# Lightning deepspeed has saved a directory instead of a file

>>> save_path = "lightning_logs/version_0/checkpoints/epoch=0-step=0.ckpt/" # doctest: +SKIP
>>> output_path = "lightning_model.pt" # doctest: +SKIP
>>> convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path) # doctest: +SKIP
Saving fp32 state dict to lightning_model.pt
"""

import os

import torch

from pytorch_lightning.utilities import _DEEPSPEED_AVAILABLE

if _DEEPSPEED_AVAILABLE:
    from deepspeed.utils.zero_to_fp32 import (
        get_fp32_state_dict_from_zero_checkpoint,
        get_model_state_file,
        get_optim_files,
    )

# Passed as ``map_location`` to ``torch.load`` in this module so that checkpoint tensors are placed on the CPU.
CPU_DEVICE = torch.device("cpu")


def ds_checkpoint_dir(checkpoint_dir: str, tag: str = None):
    """Resolve the tag sub-directory of a DeepSpeed checkpoint folder.

    Args:
        checkpoint_dir: path to the checkpoint folder that contains the tag folder.
        tag: name of the tag sub-directory. If ``None``, the tag is read from the file
            named ``latest`` inside ``checkpoint_dir`` (surrounding whitespace is stripped).

    Returns:
        ``checkpoint_dir`` joined with the resolved tag.

    Raises:
        ValueError: if ``tag`` is ``None`` and ``checkpoint_dir`` has no ``latest`` file.
        FileNotFoundError: if the resolved directory does not exist.
    """
    if tag is None:
        latest_path = os.path.join(checkpoint_dir, "latest")
        if not os.path.isfile(latest_path):
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path) as fd:
            tag = fd.read().strip()

    directory = os.path.join(checkpoint_dir, tag)

    if not os.path.isdir(directory):
        # Report the path that was looked up (previously the function object itself was formatted here).
        raise FileNotFoundError(f"Directory '{directory}' doesn't exist")
    return directory


def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir: str, output_file: str, tag: str = None):
    """Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file.

    The resulting file can be loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for
    training without DeepSpeed. Besides the consolidated weights (stored under the ``"state_dict"``
    key), every non-DeepSpeed entry of the saved model state file is kept in the output.

    Args:
        checkpoint_dir: path to the desired checkpoint folder
            (one that contains the tag-folder, like ``global_step14``).
        output_file: path to the pytorch fp32 state_dict output file (e.g. ``path/pytorch_model.bin``).
        tag: checkpoint tag used as a unique identifier for checkpoint.
            If not provided will attempt to load tag in the file named ``latest`` in the checkpoint
            folder, e.g., ``global_step14``.

    Raises:
        ValueError: if ``tag`` is ``None`` and no ``latest`` file exists in ``checkpoint_dir``.
        FileNotFoundError: if the tag folder does not exist.
    """
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    # Keys written by DeepSpeed itself; everything else in the model state file is kept
    # (additional logic to ensure we keep the lightning state dict as well from rank 0).
    deepspeed_states = {
        "module",
        "optimizer",
        "lr_scheduler",
        "csr_tensor_module_names",
        "skipped_steps",
        "global_steps",
        "dp_world_size",
        "mp_world_size",
    }

    # Forward ``tag`` so the client state is read from the same tag folder as the weights above;
    # without it an explicit ``tag`` was ignored here and the ``latest`` folder was used instead.
    tag_dir = ds_checkpoint_dir(checkpoint_dir, tag)

    # The ZeRO stage recorded in the first optimizer file selects which model state file to read.
    optim_files = get_optim_files(tag_dir)
    optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE)
    zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
    model_file = get_model_state_file(tag_dir, zero_stage)

    saved_state = torch.load(model_file, map_location=CPU_DEVICE)
    client_state = {key: value for key, value in saved_state.items() if key not in deepspeed_states}

    # State dict keys will include reference to wrapper LightningDeepSpeedModule.
    # Delete `module` prefix before saving: ``partition`` keeps only what follows the first "module.".
    client_state["state_dict"] = {key.partition("module.")[2]: value for key, value in state_dict.items()}

    print(f"Saving fp32 state dict to {output_file}")
    torch.save(client_state, output_file)
